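""" Fast primitives for the rig/well matching model.

Numba-compiled (``@njit``) kernels for entry, meeting and targeting
probabilities, plus plain-Python helpers that update targeting shares and
the detailed/simple rig state.
"""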
import numpy as np
import copy
from numba import njit


# MODEL OBJECTS ABOUT THE ENTRY OF WELLS ETC --------------------------------------------
@njit
def draws(d_0, d_1, g):
    """Return the draws level ``d_0 + d_1 * g`` (affine in ``g``)."""
    slope_term = d_1 * g
    return d_0 + slope_term


@njit
def theta(d_0, d_1, state, shares):
    """Return the per-spec ratio ``state[1:4] / (draws(state[0]) * shares)``.

    ``state[0]`` drives the draws term (``d_0 + d_1 * state[0]``) and
    ``state[1:4]`` holds one value per spec. ``shares`` is applied
    element-wise, so it presumably has one entry per spec as well.
    """
    counts_by_spec = state[1:4]
    scaled_draws = draws(d_0, d_1, state[0]) * shares
    return counts_by_spec / scaled_draws


@njit
def get_mnl(gamma, gamma_negative, ev, state):
    """Return availability-weighted multinomial-logit probabilities over specs.

    Utilities are ``gamma * ev``. Non-positive ones are replaced by the
    penalty ``-gamma * gamma_negative``. Column ``j`` is weighted by
    ``state[1 + j]``, and every row is normalised to sum to one.

    ``ev`` must be 2-D because rows are summed with ``axis=1``. Rows
    presumably index ``mri`` grid points and columns the three specs.
    """
    a = gamma * ev
    a = np.where(a > 0, a, -gamma * gamma_negative)

    # Shift each row by its own maximum before exponentiating. Softmax ratios
    # are invariant to a per-row constant. A per-row shift, unlike the global
    # max, keeps each row's largest term at exp(0) = 1, so a row far below
    # the global maximum cannot underflow to 0 / 0 = NaN.
    n_rows = a.shape[0]
    row_max = np.empty(n_rows)
    for r in range(n_rows):
        row_max[r] = np.max(a[r])
    b = np.exp(a - row_max.reshape((-1, 1)))

    num = state[1:4] * b
    denom = num.sum(axis=1).reshape((-1, 1))
    return num / denom


@njit
def exponential_cdf(x, scale):
    """Return the exponential CDF ``1 - exp(-scale * x)`` (``scale`` acts as the rate).

    Computed as ``-expm1(-scale * x)``, which avoids the cancellation in
    ``1 - exp(...)`` when ``scale * x`` is small. The limits are unchanged:
    0 at ``x = 0`` and 1 as ``x -> inf``.
    """
    return -np.expm1(-scale * x)


@njit
def q_x(a_0, a_1, d_0, d_1, state, shares):
    """Return ``a_0 * theta * (1 - exp(-a_1 / theta))`` per spec, clipped at 1.

    ``theta`` comes from :func:`theta`. This is the well-targeting
    counterpart of the ``q_x_capital`` functions, which cap at ``1 / theta``
    instead.
    """
    tightness = theta(d_0, d_1, state, shares)
    uncapped = a_0 * tightness * exponential_cdf(x=1 / tightness, scale=a_1)
    # Clip to 1. np.minimum leaves NaN entries as NaN, exactly like the
    # masked assignment `q[q > 1] = 1` would.
    return np.minimum(uncapped, 1.0)


@njit
def q_x_capital_all(a_0, a_1, d_0, d_1, state, shares):
    """Return ``a_0 * (1 - exp(-a_1 / theta))`` per spec, capped at ``1 / theta``.

    ``theta`` comes from :func:`theta`. The cap ``1 / theta`` is computed with
    array division, the same way it is for the CDF argument. A zero ``theta``
    (a spec whose ``state`` entry is 0) therefore gets an infinite cap. The
    old scalar ``1 / theta[i]`` raised Numba's ZeroDivisionError there.
    """
    theta_at_state = theta(d_0, d_1, state, shares)
    inv_theta = 1 / theta_at_state
    q = a_0 * exponential_cdf(x=inv_theta, scale=a_1)

    # np.where keeps q wherever the comparison is False, so NaN entries
    # behave exactly as in an element-wise `if q[i] > cap[i]` loop.
    return np.where(q > inv_theta, inv_theta, q)


@njit
def q_x_capital(a_0, a_1, d_0, d_1, state, shares, spec):
    """Return the capped q of a single spec.

    This is the term multiplied in as "prob. match" in ``update_state_all``.
    It delegates to :func:`q_x_capital_all`, which uses the same formula and
    cap, instead of duplicating it. It then picks the entry for ``spec``:
    ``'low'`` -> 0, ``'mid'`` -> 1, ``'high'`` -> 2. Any other ``spec`` falls
    through and returns ``None``.
    """
    q = q_x_capital_all(a_0, a_1, d_0, d_1, state, shares)

    if spec == 'low':
        return q[0]
    elif spec == 'mid':
        return q[1]
    elif spec == 'high':
        return q[2]


@njit
def get_entry_prob(c, ev, mnl):
    """Return the logit probability of entry for each row.

    The value of entering is the ``mnl``-weighted expected value
    ``v = (mnl * ev).sum(axis=1)``. The outside option is normalised to 0,
    so ``P(entry) = exp(v - c) / (1 + exp(v - c))``.
    """
    value_of_entry = (mnl * ev).sum(axis=1)
    # Algebraically identical to exp(v - c) / (1 + exp(v - c)). That form
    # overflows to inf / inf = NaN once v - c exceeds ~709; this one
    # saturates at 1 (and at 0 for very negative v - c).
    # Original author's note: this could also be done by inputting the 'true'
    # value of c rather than c/30 in the params initially.
    return 1.0 / (1.0 + np.exp(c - value_of_entry))


@njit
def get_weights(gamma, gamma_negative, c, ev, state):
    """Return joint probabilities of entering and choosing each spec.

    ``w[r, j] = P(entry)[r] * mnl[r, j]``. The entry probabilities come from
    :func:`get_entry_prob` evaluated at cost ``c``.
    """
    choice_prob = get_mnl(gamma, gamma_negative, ev, state)
    p_enter = get_entry_prob(c, ev, choice_prob)
    return choice_prob * p_enter.reshape((-1, 1))


@njit
def get_weights_given_entry_prob(gamma, gamma_negative, ev, state, entry_prob):
    """Like :func:`get_weights`, but with externally supplied entry probabilities.

    ``entry_prob`` is reshaped to a column, so it needs one value per row of
    the MNL matrix.
    """
    return get_mnl(gamma, gamma_negative, ev, state) * entry_prob.reshape((-1, 1))


@njit
def potential_matches(x, mu_0, mu_1, sigma_0, sigma_1, weight_lambda, denom):
    """Return the Normal(``mu_0``, ``sigma_0``) density at ``x``, divided by ``denom``.

    NOTE(review): ``mu_1``, ``sigma_1`` and ``weight_lambda`` are accepted but
    not used. They are possibly left over from a two-component mixture; confirm.
    """
    coef = 1 / (sigma_0 * np.sqrt(2 * np.pi))
    z = (x - mu_0) / sigma_0
    density = coef * np.exp(-0.5 * z ** 2)
    return density / denom


# GET OBJECTS ABOUT TARGETING -----------------------------------------------------------
@njit
def get_ev_targeting(state, shares, a_0, a_1, d_0, d_1, surp_array, well_target=True):
    """Return the expected value of targeting each spec: ``q * surp_array``.

    ``q`` has one entry per spec. It comes from :func:`q_x` when
    ``well_target`` is true, and from :func:`q_x_capital_all` otherwise.
    """
    # 1. Get the meeting probability array.
    if well_target:
        q = q_x(a_0, a_1, d_0, d_1, state, shares)
    else:
        q = q_x_capital_all(a_0, a_1, d_0, d_1, state, shares)

    # 2. Expected value; q broadcasts against surp_array's last axis.
    return q * surp_array


def w_times_f(mri, state, shares, params, surp_array, verbose=False, entry_prob=None,
              use_absolute_advantage=False, well_target=True):
    """Return targeting weights times the potential-match density over ``mri``.

    Args:
        mri: Grid point(s) at which :func:`potential_matches` is evaluated.
            Any scalar (Python ``int``/``float`` or a NumPy scalar) is wrapped
            into a length-1 float array.
        state: ``[g, n_low, n_mid, n_high]``.
        shares: Passed through to ``theta`` (divided element-wise per spec).
        params: Parameter dict. Always reads ``gamma``, ``gamma_negative``,
            ``mu_0``, ``mu_1``, ``sigma_0``, ``sigma_1``, ``weight_lambda``
            and ``denom``. Reads ``a_0``, ``a_1``, ``d_0``, ``d_1`` unless
            ``use_absolute_advantage``, and ``c`` only when ``entry_prob`` is None.
        surp_array: Surplus array, combined with meeting probabilities into ``ev``.
        verbose: Return diagnostic intermediates instead (see Returns).
        entry_prob: Fixed entry probabilities, one per row. When None they
            are computed from ``params['c']`` (important for counterfactuals).
        use_absolute_advantage: Use ``surp_array`` itself as ``ev`` instead of
            meeting probability times surplus.
        well_target: Forwarded to :func:`get_ev_targeting`.

    Returns:
        ``(output, no_entry_prob_times_f, entry_prob)``, where ``output = w * f``
        has one row per grid point and one column per spec. When ``verbose``,
        returns ``(output, ev, w, f_column)`` instead.
    """
    # Wrap any scalar so that f is always an array supporting .reshape below.
    # The old `type(mri) == float` check missed ints and NumPy scalars.
    if np.isscalar(mri):
        mri = np.array([mri], dtype=np.float64)

    if use_absolute_advantage:
        # Raw surplus is used directly, so skip the meeting-probability work.
        ev = surp_array
    else:
        ev = get_ev_targeting(state, shares, params['a_0'], params['a_1'],
                              params['d_0'], params['d_1'], surp_array, well_target)

    # Different conditions if entry_prob is specified or not (important for counterfactuals).
    if entry_prob is None:
        w = get_weights(params['gamma'], params['gamma_negative'], params['c'], ev, state)
        # MNL rows sum to one, so the row sums of w recover the entry probabilities.
        entry_prob = w.sum(axis=1)
    else:
        w = get_weights_given_entry_prob(params['gamma'], params['gamma_negative'],
                                         ev, state, entry_prob)

    f = potential_matches(
        mri, params['mu_0'], params['mu_1'], params['sigma_0'], params['sigma_1'],
        params['weight_lambda'], params['denom']
    )
    f_column = f.reshape((-1, 1))
    output = w * f_column
    if verbose:
        return output, ev, w, f_column

    # Density mass at each grid point that does not enter.
    no_entry_prob_times_f = (1.0 - entry_prob) * f
    return output, no_entry_prob_times_f, entry_prob


def update_shares(mri, delta, state, shares, params, surp_array_by_tau, verbose=False,
                  entry_prob_by_tau=None, use_absolute_advantage=False, well_target=True):
    """Return each spec's share of targeting mass, per contract length ``tau``.

    For every ``tau`` in (2, 3, 4), :func:`w_times_f` gives the targeting mass
    on the ``mri`` grid. A spec's share is its summed mass divided by the
    total: all three specs plus the summed no-entry mass.

    Args:
        mri, state, shares, params, use_absolute_advantage, well_target:
            Forwarded to :func:`w_times_f`.
        delta: Accepted but unused.
        surp_array_by_tau: Surplus arrays keyed by ``tau``.
        verbose: Also return intermediates (see Returns).
        entry_prob_by_tau: Optional fixed entry probabilities keyed by ``tau``.

    Returns:
        ``share_by_spec_tau`` keyed by ``(spec, tau)``. When ``verbose``, returns
        ``(share_by_spec_tau, target_by_tau, denom, entry_prob_by_tau_new)``,
        where ``denom`` is the normaliser of the last ``tau`` (4) only.
    """
    specs = ('low', 'mid', 'high')
    target_by_tau = {}
    target_sum_by_tau = {}
    share_by_spec_tau = {}
    entry_prob_by_tau_new = {}

    for tau in (2, 3, 4):
        fixed_entry_prob = None if entry_prob_by_tau is None else entry_prob_by_tau[tau]
        targets, no_entry_mass, entry_prob_by_tau_new[tau] = w_times_f(
            mri, state, shares, params, surp_array_by_tau[tau],
            verbose=False, entry_prob=fixed_entry_prob,
            use_absolute_advantage=use_absolute_advantage,
            well_target=well_target,
        )
        target_by_tau[tau] = targets

        # Normaliser: targeted mass of every spec plus the no-entry mass.
        denom = 0
        for col, spec in enumerate(specs):
            target_sum_by_tau[(spec, tau)] = targets[:, col].sum()
            denom += target_sum_by_tau[(spec, tau)]
        denom += no_entry_mass.sum()

        for spec in specs:
            share_by_spec_tau[(spec, tau)] = target_sum_by_tau[(spec, tau)] / denom

    if verbose:
        return share_by_spec_tau, target_by_tau, denom, entry_prob_by_tau_new
    return share_by_spec_tau


#%% GET STATE UPDATING ------------------------------------------------------------------
def get_extensions(cutoff_mask_by_spec, params, state_detail):
    """Find how many extensions will occur.

    For each spec and each contract length ``tau`` in (2, 3, 4), take
    column 0 of ``state_detail[spec][i]`` (``i`` = position of ``tau``) and
    scale it by ``params['eta']``. When ``cutoff_mask_by_spec`` is given,
    also multiply by ``cutoff_mask_by_spec[spec][i]``.

    Returns:
        Nested dict ``{spec: {tau: array}}``. Each array is freshly allocated
        by the multiplication, so it never aliases ``state_detail``.
    """
    extensions_by_spec_tau = {}
    for spec in ('low', 'mid', 'high'):
        per_tau = {}
        for idx, tau in enumerate((2, 3, 4)):
            rigs_in_final_period = state_detail[spec][idx][:, 0]
            if cutoff_mask_by_spec is not None:
                per_tau[tau] = (rigs_in_final_period * params['eta']
                                * cutoff_mask_by_spec[spec][idx])
            else:
                per_tau[tau] = rigs_in_final_period * params['eta']
        extensions_by_spec_tau[spec] = per_tau
    return extensions_by_spec_tau


def get_state_simple(extensions_by_spec_tau, state_detail, n_rigs, g, verbose=False):
    """Collapse the detailed state into ``[g, n_low, n_mid, n_high]`` after extensions.

    For each spec:

    - ``prob_unemployed`` is one minus the total mass of the three arrays in
      ``state_detail[spec]``.
    - ``prob_no_extend`` is the mass in column 0 of those arrays minus the
      summed extensions for tau = 2, 3, 4.

    Their sum is scaled by ``n_rigs[spec]``.

    Returns:
        ``state_simple`` (length-4 float array). When ``verbose``, returns
        ``(state_simple, prob_unemployed, prob_no_extend)``, where the last two
        are the values of the final spec processed (``'high'``).
    """
    state_simple = np.zeros(4)

    for slot, spec in enumerate(('low', 'mid', 'high'), start=1):
        detail = state_detail[spec]
        prob_unemployed = 1 - detail[0].sum() - detail[1].sum() - detail[2].sum()

        ext = extensions_by_spec_tau[spec]
        prob_no_extend = (
            detail[0][:, 0].sum() + detail[1][:, 0].sum() + detail[2][:, 0].sum()
            - ext[2].sum() - ext[3].sum() - ext[4].sum()
        )
        state_simple[slot] = (prob_unemployed + prob_no_extend) * n_rigs[spec]

    state_simple[0] = g

    if verbose:
        return state_simple, prob_unemployed, prob_no_extend
    return state_simple


def update_state_all(cutoffs_mask, state_detail, state,
                     shares, mri_state_grid, params, surp_array_by_tau,
                     extensions_by_spec_tau, n_rigs,
                     entry_prob_by_tau=None, verbose=False,
                     use_absolute_advantage=False, well_target=True):
    """Advance the detailed rig state by one period.

    Steps:
      1. Targeting: for each contract length ``tau`` in (2, 3, 4), compute
         targeting mass on ``mri_state_grid`` with :func:`w_times_f`.
      2. Matching: for each spec, the match probability is
         ``target_with_cutoff / target_sum * q * share_available``.
         ``target_with_cutoff`` is targeting scaled by ``params['p_tau']``
         and ``cutoffs_mask``. ``target_sum`` ignores the cutoffs. ``q``
         comes from :func:`q_x_capital`, with NaN replaced by 1.0.
      3. Shift each ``state_detail[spec][i]`` one column left (``np.roll``).
         Then write new matches plus ``extensions_by_spec_tau`` into the
         last column.

    Note:
        ``state_detail`` is modified in place: its entries
        ``state_detail[spec][i]`` are reassigned to the rolled arrays.
        ``verbose`` is accepted but unused.

    Returns:
        dict with these keys:

        - ``'state_detail'``
        - ``'extension_prob_by_tau'``: always empty; kept for compatibility.
        - ``'share_available'``
        - ``'prob_match_by_spec_tau'``
        - ``'target_with_cutoff'``: keyed by ``(spec, tau)`` for every spec.
        - ``'target_sum_by_spec'``
    """
    specs = ['low', 'mid', 'high']
    taus = [2, 3, 4]

    # GET TARGETING -----------------------------------------------------------
    target_by_tau = dict()
    no_entry_prob_times_f_by_tau = dict()
    entry_prob_new_by_tau = dict()
    for tau in taus:
        entry_prob = None if entry_prob_by_tau is None else entry_prob_by_tau[tau]
        (
            target_by_tau[tau],
            no_entry_prob_times_f_by_tau[tau],
            entry_prob_new_by_tau[tau]
        ) = w_times_f(
            mri=mri_state_grid,
            state=state,
            shares=shares,
            params=params,
            surp_array=surp_array_by_tau[tau],
            entry_prob=entry_prob,
            use_absolute_advantage=use_absolute_advantage,
            well_target=well_target
        )

    # GET MATCHING ------------------------------------------------------------
    q_x_by_spec = dict()
    share_available_by_spec = dict()
    prob_match_by_spec_tau = dict()
    target_sum_by_spec = dict()
    # Accumulates every (spec, tau) pair so the returned dict covers all specs.
    target_with_cutoff = dict()
    for k, spec in enumerate(specs):
        target_without_cutoff = dict()
        for i, tau in enumerate(taus):
            p_tau = params[f'p_{tau}']
            target_with_cutoff[(spec, tau)] = (
                p_tau * target_by_tau[tau][:, k] * cutoffs_mask[spec][i])
            target_without_cutoff[(spec, tau)] = p_tau * target_by_tau[tau][:, k]

        # Prob. a spec-k rig is available to match (does not depend on tau).
        share_available_by_spec[spec] = state[k + 1] / n_rigs[spec]

        # Sum of targeting (without thinking about whether the match is rejected),
        # floored so the division below never hits zero.
        target_sum = sum(target_without_cutoff[(spec, tau)].sum() for tau in taus)
        if target_sum < 1e-15:
            target_sum = 1e-15
        target_sum_by_spec[spec] = target_sum

        q = q_x_capital(params['a_0'], params['a_1'],
                        params['d_0'], params['d_1'],
                        state, shares, spec)
        # NaN never compares equal to anything (so `q == np.nan` is always
        # False); np.isnan is required for this fallback to ever apply.
        if np.isnan(q):
            q = 1.0
        q_x_by_spec[spec] = q

        for tau in taus:
            # Prob. rig matches with contract length tau:
            #   prob. targeted and not rejected
            #   x prob. match
            #   x prob. rig is available to match
            prob_match_by_spec_tau[(spec, tau)] = (
                (target_with_cutoff[(spec, tau)] / target_sum_by_spec[spec])
                * q_x_by_spec[spec] * share_available_by_spec[spec])

    # PUT IT ALL TOGETHER -----------------------------------------------------
    # Never populated; returned only so callers reading this key keep working.
    extension_prob_by_tau = dict()
    for i, tau in enumerate(taus):
        for spec in specs:
            # Extensions have not been written into the detailed state yet.
            # After the roll the old column 0 wraps to the last column, which
            # is then overwritten with new matches plus extensions.
            state_detail[spec][i] = np.roll(state_detail[spec][i], -1, axis=1)
            state_detail[spec][i][:, -1] = \
                prob_match_by_spec_tau[(spec, tau)] + extensions_by_spec_tau[spec][tau]

    return {
        'state_detail': state_detail,
        'extension_prob_by_tau': extension_prob_by_tau,
        'share_available': share_available_by_spec,
        'prob_match_by_spec_tau': prob_match_by_spec_tau,
        'target_with_cutoff': target_with_cutoff,
        'target_sum_by_spec': target_sum_by_spec
    }

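# Picklable version of the extension-probability predictor: a plain module-level
# function (presumably replacing a non-picklable closure/lambda — confirm).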
@njit
def prob_extension_predictor_with_mri(coefs, state, tau, mri):
    """Return a logistic prediction, presumably the probability of an extension.

    The linear index ``x`` is a full quadratic in ``(g, n_l, n_m, n_h) =
    state[0:4]`` (coefficients 0-14), plus linear terms in ``tau``
    (``coefs[15]``) and ``mri`` (``coefs[16]``). The result is
    ``logistic(x) = 1 / (1 + exp(-x))``.
    """
    g = state[0]
    n_l = state[1]
    n_m = state[2]
    n_h = state[3]

    x = (
        coefs[0]
        + coefs[1] * g
        + coefs[2] * n_l
        + coefs[3] * n_m
        + coefs[4] * n_h
        + coefs[5] * g * g
        + coefs[6] * g * n_l
        + coefs[7] * g * n_m
        + coefs[8] * g * n_h
        + coefs[9] * n_l * n_l
        + coefs[10] * n_l * n_m
        + coefs[11] * n_l * n_h
        + coefs[12] * n_m * n_m
        + coefs[13] * n_m * n_h
        + coefs[14] * n_h * n_h
        + coefs[15] * tau
        + coefs[16] * mri
    )

    # Numerically stable logistic. exp(x) / (1 + exp(x)) overflows to
    # inf / inf = NaN once x exceeds ~709. This form saturates cleanly at 1,
    # and at 0 for very negative x.
    return 1.0 / (1.0 + np.exp(-x))
